# Functions

# Plot color palette
# Draws one stacked bar with a segment per entry of `input_cols` so a palette
# can be previewed at a glance.
#
# input_cols: character vector of colors (hex codes or R color names).
# Returns a ggplot object; segments follow the order of `input_cols`.
plot_color_palette <- function(input_cols) {
  # First-appearance factor levels keep the palette order in the bar
  palette_df <- tibble(color = fct_inorder(input_cols))

  ggplot(palette_df, aes(x = "color", fill = color)) +
    geom_bar() +
    scale_fill_manual(values = input_cols) +
    theme_void()
}

# Merge a named list of Seurat objects, scale the "RNA" and "adt" assays and
# find variable RNA features.
#
# sobj_list:    list of Seurat objects; list names are used as cell-ID
#               prefixes so barcodes stay unique after merging.
# sample_order: optional level order for the `orig.ident` metadata column.
#
# Returns the merged Seurat object.
# NOTE(review): no normalization step runs here; presumably the inputs are
# already normalized before ScaleData — confirm.
merge_sobj <- function(sobj_list, sample_order = NULL) {

  if (length(sobj_list) == 0) {
    stop("sobj_list must contain at least one Seurat object", call. = FALSE)
  }

  if (length(sobj_list) == 1) {
    # merge() needs something to merge with; `2:length()` would give c(2, 1)
    # here. Prefix the cell names the way merge(add.cell.ids = ) would.
    res <- sobj_list[[1]]

    if (!is.null(names(sobj_list))) {
      res <- RenameCells(res, add.cell.id = names(sobj_list)[1])
    }
  } else {
    res <- merge(
      x            = sobj_list[[1]],
      y            = sobj_list[-1],
      add.cell.ids = names(sobj_list)
    )
  }

  res <- res %>%
    ScaleData(assay = "RNA") %>%
    ScaleData(assay = "adt") %>%
    FindVariableFeatures(assay = "RNA")

  # Set sample order (a NULL sample_order still converts orig.ident to factor)
  res@meta.data <- res@meta.data %>%
    rownames_to_column("cell_ids") %>%
    mutate(orig.ident = fct_relevel(orig.ident, sample_order)) %>%
    column_to_rownames("cell_ids")

  res
}

# Run PCA, build a nearest-neighbor graph, cluster cells and run UMAP on the
# given assay.
#
# sobj_in:    Seurat object.
# assay:      assay used for PCA, neighbors and UMAP.
# resolution: FindClusters resolution.
# dims:       PCs used by FindNeighbors and RunUMAP.
# prefix:     prefix for the PC_1/PC_2 metadata columns and for the UMAP
#             reduction name and key.
# ...:        passed on to RunPCA.
#
# Returns the Seurat object with prefixed PC_1/PC_2 copies in meta.data,
# cluster IDs in meta.data column "<assay>_clusters", and a "<prefix>umap"
# reduction.
cluster_RNA <- function(sobj_in, assay = "RNA", resolution = 0.6,
                        dims = 1:40, prefix = "", ...) {
  pc_cols <- c("PC_1", "PC_2")

  # PCA (Seurat uses only variable features by default); keep the first two
  # PCs in meta.data under prefixed names
  sobj <- RunPCA(sobj_in, assay = assay, ...)
  sobj <- AddMetaData(
    sobj,
    metadata = FetchData(sobj, pc_cols),
    col.name = str_c(prefix, pc_cols)
  )

  # FindNeighbors builds a KNN graph in PCA space and reweights edges by the
  # Jaccard overlap of neighborhoods; FindClusters then applies modularity
  # optimization (Louvain by default)
  sobj <- FindNeighbors(sobj, assay = assay, reduction = "pca", dims = dims)
  sobj <- FindClusters(sobj, resolution = resolution, verbose = FALSE)
  sobj <- AddMetaData(
    sobj,
    metadata = Idents(sobj),
    col.name = str_c(assay, "_clusters")
  )

  # UMAP embedding stored under a prefixed reduction name and key
  RunUMAP(
    sobj,
    assay          = assay,
    dims           = dims,
    reduction.name = str_c(prefix, "umap"),
    reduction.key  = str_c(prefix, "UMAP_")
  )
}

# Fit gaussian mixture model for given signal
#
# Fits a two-component normal mixture (normalmixEM) to one column fetched from
# a Seurat object and labels the components "low"/"high" by their means.
#
# sobj_in:     Seurat object.
# data_column: column fetched with FetchData (default "adt_ovalbumin").
#
# Returns a list with:
#   res    - data.frame (rownames = cell IDs) holding the signal, the
#            posterior of each component ("low", "high"), and GMM_grp
#            ("low" when the "low" posterior is > 0.5, else "high");
#   mu, sigma, lambda - component means, SDs and mixing weights named
#            "low"/"high".
fit_GMM <- function(sobj_in, data_column = "adt_ovalbumin") {

  data_df <- FetchData(sobj_in, data_column)
  mixmdl  <- normalmixEM(data_df[[data_column]])

  # The component with the smaller mean is labeled "low"
  comp_labels <- c("low", "high")

  if (mixmdl$mu[1] > mixmdl$mu[2]) {
    comp_labels <- c("high", "low")
  }

  # new name -> posterior column, as expected by rename()
  comp_cols <- setNames(c("comp.1", "comp.2"), comp_labels)

  post_df <- data.frame(
    cell_id = rownames(data_df),
    data    = data_df[, data_column],
    mixmdl$posterior
  )

  grp_df <- post_df %>%
    dplyr::rename(!!sym(data_column) := data) %>%
    rename(all_of(comp_cols)) %>%
    mutate(GMM_grp = if_else(low > 0.5, "low", "high")) %>%
    column_to_rownames("cell_id")

  list(
    res    = grp_df,
    mu     = setNames(mixmdl$mu, comp_labels),
    sigma  = setNames(mixmdl$sigma, comp_labels),
    lambda = setNames(mixmdl$lambda, comp_labels)
  )
}

# Add distribution of GMM component to plot
#
# Returns a stat_function layer drawing the weighted normal density
# lambda * dnorm(x, mu, sigma) of one mixture component.
#
# gmm_in:  list with named mu, sigma and lambda vectors (e.g. from fit_GMM).
# cols_in: named colors; cols_in[key] is the line color.
# key:     component name, e.g. "low" or "high".
add_stat_fun <- function(gmm_in, cols_in, key) {
  # Density of one component scaled by its mixture weight
  weighted_density <- function(x, mu, sigma, lam) lam * dnorm(x, mu, sigma)

  comp_args <- list(gmm_in$mu[key], gmm_in$sigma[key], gmm_in$lambda[key])

  stat_function(
    geom  = "line",
    fun   = weighted_density,
    args  = comp_args,
    color = cols_in[key],
    lwd   = 1
  )
}

# Overlay feature data on UMAP or tSNE coordinates.
# (Used instead of FeaturePlot, which cannot change the number of columns
# when using split.by.)
#
# sobj_in:         Seurat object, or a data.frame/tibble already holding the
#                  x, y, feature (and split_id) columns.
# x, y:            axis columns; a named value renames the column
#                  (c(new_name = "old_name")).
# feature:         column used for point color; a named value renames it.
# pt_size:         point size.
# split_id:        optional column used to facet the plot.
# plot_cols:       low/high gradient ends for a numeric feature, otherwise
#                  manual colors for a discrete feature.
# feat_levels:     optional level order for a discrete feature.
# split_levels:    optional facet order.
# min_pct/max_pct: optional percentile-rank cutoffs (0-1) used to clamp the
#                  feature values; either may be given alone.
# calc_cor:        add a Pearson correlation label for x vs y (per facet).
# lab_size:        text size of the correlation label.
# short_feat_name: drop a trailing "-A1234"-style suffix from the feature name.
# lab_pos:         label position as fractions of the x and y ranges.
# lm_line:         add a dashed linear regression line.
# pt_outline:      if set, black points of this size are drawn underneath to
#                  outline the data points.
# show_title:      show the feature name as a facet-style title.
# ...:             passed to facet_wrap() when split_id is set.
#
# Returns a ggplot object.
plot_features <- function(sobj_in, x = "UMAP_1", y = "UMAP_2", feature, pt_size = 0.25,
                          split_id = NULL, plot_cols = c("#fafafa", "#e31a1c"),
                          feat_levels = NULL, split_levels = NULL, min_pct = NULL, 
                          max_pct = NULL, calc_cor = F, lab_size = 3.7, short_feat_name = T,
                          lab_pos = c(0.8, 0.9), lm_line = F, pt_outline = NULL, 
                          show_title = F, ...) {

  # Format input data
  counts     <- sobj_in
  short_feat <- feature

  if (short_feat_name) {
    short_feat <- feature %>%
      str_remove("\\-[A-Z][0-9]{4}$")
  }

  if (inherits(sobj_in, "Seurat")) {
    vars <- c(x, y, feature)

    if (!is.null(split_id)) {
      vars <- c(vars, split_id)
    }

    counts <- sobj_in %>%
      FetchData(vars = vars) %>%
      as_tibble(rownames = "cell_ids")
  }

  counts <- counts %>%
    rename(!!sym(short_feat) := !!sym(feature))

  # Rename features using the names of `feature`
  if (!is.null(names(feature))) {
    names(short_feat) <- names(feature)

    counts <- counts %>%
      rename(!!!syms(short_feat))

    short_feat <- names(short_feat)
  }

  # Rename axis columns using the names of `x` / `y`
  if (!is.null(names(x))) {
    counts <- counts %>%
      rename(!!!syms(x))

    x <- names(x)
  }

  if (!is.null(names(y))) {
    counts <- counts %>%
      rename(!!!syms(y))

    y <- names(y)
  }

  # Long format so facet_wrap(~ key) can display the feature name as a title
  if (show_title) {
    counts <- counts %>%
      gather(key, value, !!sym(short_feat))

    short_feat <- "value"
  }

  # Clamp feature values by percentile rank. Values ranked above max_pct are
  # set to the smallest such value, values ranked below min_pct to the largest
  # such value. The feature column itself is overwritten so the clamped values
  # are what gets plotted.
  if (!is.null(min_pct) || !is.null(max_pct)) {
    if (is.null(min_pct)) {
      min_pct <- 0
    }

    if (is.null(max_pct)) {
      max_pct <- 1
    }

    feat_vals <- counts[[short_feat]]
    pct_rank  <- percent_rank(feat_vals)
    above_max <- feat_vals[which(pct_rank > max_pct)]
    below_min <- feat_vals[which(pct_rank < min_pct)]

    # Nothing beyond a cutoff means no clamping on that side
    max_val <- if (length(above_max) > 0) min(above_max) else Inf
    min_val <- if (length(below_min) > 0) max(below_min) else -Inf

    counts[[short_feat]] <- pmax(pmin(feat_vals, max_val), min_val)
  }

  # Set feature order
  if (!is.null(feat_levels)) {
    counts <- counts %>%
      mutate(!!sym(short_feat) := fct_relevel(!!sym(short_feat), feat_levels))
  }

  # Set facet order; the split column is renamed to "split_id" from here on
  if (!is.null(split_id)) {
    counts <- counts %>%
      rename(split_id = !!sym(split_id))

    if (!is.null(split_levels)) {
      counts <- counts %>%
        mutate(split_id = fct_relevel(split_id, split_levels))
    }
  }

  # Pearson correlation of x and y (per facet) and the label position at
  # lab_pos[1] of the x range and lab_pos[2] of the y range
  if (calc_cor) {
    if (!is.null(split_id)) {
      # The column was renamed above, so group by "split_id", not the
      # original column name
      counts <- counts %>%
        group_by(split_id)
    }

    counts <- counts %>%
      mutate(
        cor_lab = cor(!!sym(x), !!sym(y)),
        cor_lab = round(cor_lab, digits = 2),
        cor_lab = str_c("r = ", cor_lab),
        min_x   = min(!!sym(x)),
        max_x   = max(!!sym(x)),
        min_y   = min(!!sym(y)),
        max_y   = max(!!sym(y)),
        lab_x   = (max_x - min_x) * lab_pos[1] + min_x,
        lab_y   = (max_y - min_y) * lab_pos[2] + min_y
      ) %>%
      ungroup()
  }

  # Create scatter plot; sorting draws high values on top
  res <- counts %>%
    arrange(!!sym(short_feat)) %>%
    ggplot(aes(!!sym(x), !!sym(y), color = !!sym(short_feat)))

  if (!is.null(pt_outline)) {
    res <- res +
      geom_point(aes(fill = !!sym(short_feat)), size = pt_outline, color = "black", show.legend = F)
  }

  res <- res +
    geom_point(size = pt_size)

  # Add regression line
  if (lm_line) {
    res <- res +
      geom_smooth(method = "lm", se = F, color = "black", size = 0.5, linetype = 2)
  }

  # Add correlation coefficient label
  if (calc_cor) {
    res <- res +
      geom_text(
        aes(x = lab_x, y = lab_y, label = cor_lab),
        color = "black",
        size  = lab_size,
        check_overlap = T,
        show.legend = F
      )
  }

  # Show facet-style title
  if (show_title) {
    res <- res +
      facet_wrap(~ key, scales = "free") +
      theme(legend.title = element_blank())
  }

  # Set feature colors
  if (is.numeric(counts[[short_feat]])) {
    res <- res +
      scale_color_gradient(low = plot_cols[1], high = plot_cols[2])

  } else {
    res <- res +
      scale_color_manual(values = plot_cols)
  }

  # Split plot into facets
  if (!is.null(split_id)) {
    res <- res +
      facet_wrap(~ split_id, ...)
  }

  res
}

# Run g:Profiler enrichment (gprofiler2::gost) on a gene list.
#
# gene_list: character vector of genes.
# genome:    genome build such as "GRCh38"; trailing digits are dropped and
#            the prefix (GRCm, GRCh, BDGP) is mapped to a g:Profiler organism.
# gmt_id:    custom g:Profiler GMT ID; takes precedence over `genome` and
#            clears `dbases`.
# dbases:    annotation sources passed to gost(sources = ).
# ...:       further arguments for gost().
#
# Returns a tibble of significant terms sorted by source and p-value, an empty
# tibble when gost() finds no significant terms, or NULL when gene_list is
# empty.
run_gprofiler <- function(gene_list, genome = NULL, gmt_id = NULL,
                          dbases = c("GO:BP", "GO:MF", "KEGG"), ...) {

  # Check for empty gene list
  if (is_empty(gene_list)) {
    return(NULL)
  }

  # Check arguments
  if (is.null(genome) && is.null(gmt_id)) {
    stop("ERROR: Must specify genome or gmt_id")
  }

  # Resolve the organism for gost()
  if (!is.null(gmt_id)) {
    org    <- gmt_id
    dbases <- NULL

  } else {
    genomes <- list(
      GRCm = "mmusculus",
      GRCh = "hsapiens",
      BDGP = "dmelanogaster"
    )

    genome_prefix <- str_remove(genome, "[0-9]+$")

    if (!genome_prefix %in% names(genomes)) {
      stop(
        "ERROR: Unsupported genome '", genome, "'; expected a build starting with ",
        str_c(names(genomes), collapse = ", ")
      )
    }

    org <- genomes[[genome_prefix]]
  }

  # Run gProfileR
  gost_res <- gost(
    query        = gene_list,
    organism     = org,
    sources      = dbases,
    domain_scope = "annotated",
    significant  = TRUE,
    ...
  )

  # gost() returns NULL when no term is significant
  if (is.null(gost_res)) {
    return(tibble())
  }

  # Format and sort output data.frame
  gost_res$result %>%
    as_tibble() %>%
    arrange(source, p_value)
}

# Create GO bubble plot
#
# GO_df:       g:Profiler results (e.g. from run_gprofiler) with term_id,
#              term_name, source, p_value and intersection_size columns.
# plot_colors: point color, passed as a fixed `color` to geom_point(); it must
#              have length 1 or one value per row of GO_df.
#              NOTE(review): the default has four colors, which only works
#              when GO_df has exactly four rows — confirm the intended default.
# n_terms:     number of lowest-p-value terms labeled per source.
#
# Returns a ggplot with one facet per annotation source (bubble size =
# intersection size), or an empty plot when GO_df is empty.
create_bubbles <- function(GO_df, plot_colors = theme_cols[c(1:2, 4, 9)],
                           n_terms = 15) {

  # Check for empty inputs
  if (is_empty(GO_df) || nrow(GO_df) == 0) {
    res <- ggplot() +
      geom_blank()

    return(res)
  }

  # Label = term ID without its "GO:"/"KEGG:" prefix plus the term name,
  # lowercased and truncated to 40 characters; shorten the source names
  GO_data <- GO_df %>%
    mutate(
      term_id = str_remove(term_id, "(GO|KEGG):"),
      term_id = str_c(term_id, " ", term_name),
      term_id = str_to_lower(term_id),
      term_id = str_trunc(term_id, 40, "right"),
      source  = fct_recode(
        source,
        "Biological\nProcess" = "GO:BP",
        "Cellular\nComponent" = "GO:CC",
        "Molecular\nFunction" = "GO:MF",
        "KEGG"                = "KEGG"
      )
    )

  # Reorder database names
  plot_levels <- c(
    "Biological\nProcess",
    "Cellular\nComponent",
    "Molecular\nFunction",
    "KEGG"
  )

  GO_data <- GO_data %>%
    mutate(source = fct_relevel(source, plot_levels))

  # Top terms per database; slice_head() returns at most n_terms rows per
  # group and none for n_terms = 0 (slice(1:0) would return the first row)
  top_GO <- GO_data %>%
    group_by(source) %>%
    arrange(p_value) %>%
    dplyr::slice_head(n = n_terms) %>%
    ungroup()

  # Create bubble plots
  res <- GO_data %>%
    ggplot(aes(1.25, -log10(p_value), size = intersection_size)) +
    geom_point(color = plot_colors, alpha = 0.5, show.legend = T) +
    geom_text_repel(
      aes(2, -log10(p_value), label = term_id),
      data         = top_GO,
      size         = 2.3,
      direction    = "y",
      hjust        = 0,
      segment.size = NA
    ) +
    xlim(1, 8) +
    labs(y = "-log10(p-value)") +
    theme_info +
    theme(
      axis.title.x    = element_blank(),
      axis.text.x     = element_blank(),
      axis.ticks.x    = element_blank()
    ) +
    facet_wrap(~ source, scales = "free", nrow = 1)

  res
}

# Plot percentage of cells in given groups
#
# Bar chart of cells per `group_id` colored by `fill_id` from the Seurat
# meta.data; with the default bar_pos = "fill" bars show fractions.
#
# sobj_in:     Seurat object.
# group_id:    meta.data column for the x axis.
# split_id:    optional meta.data column used for faceting.
# group_order: optional level order for the x axis.
# fill_id:     meta.data column used for bar fill.
# plot_colors: fill colors.
# x_lab/y_lab: axis labels.
# bar_pos:     geom_bar position.
# order_count: order fill levels by their number of distinct cells.
# bar_line:    outline size of the bars.
# ...:         passed to facet_wrap().
plot_cell_count <- function(sobj_in, group_id, split_id = NULL, group_order = NULL,
                            fill_id, plot_colors = theme_cols,
                            x_lab = "Cell type", y_lab = "Fraction of cells",
                            bar_pos = "fill", order_count = T, bar_line = 0, ...) {

  # Copy the requested columns to the fixed names used by the plot
  cell_df <- sobj_in@meta.data %>%
    rownames_to_column("cell_ids") %>%
    mutate(
      group_id := !!sym(group_id),
      fill_id  := !!sym(fill_id)
    )

  if (!is.null(group_order)) {
    cell_df$group_id <- fct_relevel(cell_df$group_id, group_order)
  }

  if (!is.null(split_id)) {
    cell_df <- mutate(cell_df, split_id := !!sym(split_id))
  }

  # Order fill levels by the number of distinct cells in each level
  if (order_count) {
    cell_df$fill_id <- fct_reorder(cell_df$fill_id, cell_df$cell_ids, n_distinct)
  }

  count_plot <- ggplot(cell_df, aes(group_id, fill = fill_id)) +
    geom_bar(position = bar_pos, size = bar_line, color = "black") +
    scale_fill_manual(values = plot_colors) +
    labs(x = x_lab, y = y_lab) +
    theme_info

  if (is.null(split_id)) {
    return(count_plot)
  }

  count_plot + facet_wrap(~ split_id, ...)
}

# Plot bootstrapped confidence intervals for the median of each group.
#
# input_sobj:   Seurat object.
# group_column: column defining the groups (y axis); groups are ordered by
#               their median.
# data_column:  numeric column whose median is bootstrapped (x axis, log10).
# box_cols:     colors for the groups.
# pseudo:       pseudocount added to the data before bootstrapping so values
#               can be shown on the log10 axis; negative medians/CI bounds are
#               replaced with it.
# ...:          not used.
#
# Returns a ggplot with a violin of the shifted data per group and horizontal
# error bars for 90%, 95% and 99% basic bootstrap intervals (10,000
# resamples) around the median.
create_ci_boxes <- function(input_sobj, group_column, data_column, box_cols, pseudo = 0.03, ...) {

  # Bootstrap the median of data_in; one row per confidence level with the
  # observed median (med) and the interval bounds
  get_boots <- function(data_in, conf = c(0.9, 0.95, 0.99), R = 10000, ...) {

    # Basic bootstrap interval for a single confidence level
    get_ci <- function(conf, boot_in, ...) {

      res <- boot.ci(
        boot.out = boot_in,
        conf     = conf,
        type     = "basic",
        ...
      )

      # Columns 4 and 5 of res$basic hold the interval endpoints
      res <- tibble(
        med   = res$t0,
        conf  = str_c(conf * 100, "%"),
        lower = res$basic[4],
        upper = res$basic[5]
      )

      res
    }

    boot_obj <- boot(
      data = data_in,
      statistic = function(x, i) median(x[i]),
      R = R
    )

    names(conf) <- conf

    res <- conf %>%
      map(get_ci, boot_obj) %>%
      bind_rows()

    res
  }

  box_data <- input_sobj %>%
    FetchData(c(group_column, data_column)) %>%
    as_tibble(rownames = "cell_id") %>%
    mutate(
      grp = !!sym(group_column),
      grp = fct_reorder(grp, !!sym(data_column), median),
      !!sym(data_column) := !!sym(data_column) + pseudo
    )

  # Bootstrap per group; negative values cannot be shown on a log axis
  conf_df <- box_data %>%
    group_by(grp) %>%
    summarize(boot_res = list(get_boots(!!sym(data_column)))) %>%
    ungroup() %>%
    unnest(cols = boot_res) %>%
    pivot_longer(cols = c(-grp, -conf), names_to = "key", values_to = "vals") %>%
    mutate(vals = if_else(vals < 0, pseudo, vals)) %>%
    pivot_wider(names_from = key, values_from = vals)

  # Wider, more opaque bars for narrower confidence levels
  conf_sizes <- c(
    `90%` = 4,
    `95%` = 3,
    `99%` = 2
  )

  conf_alphas <- c(
    `90%` = 1,
    `95%` = 0.5,
    `99%` = 0.25
  )

  # guide = "none" replaces the deprecated guide = FALSE
  res <- conf_df %>%
    ggplot(aes(med, grp, color = grp)) +
    geom_violin(data = box_data, aes(!!sym(data_column), grp), fill = "#f0f0f0", color = "#f0f0f0", size = 0.2) +
    geom_errorbarh(aes(xmin = lower, xmax = upper, alpha = conf, size = conf), height = 0) +
    geom_point(shape = 22, size = 1, fill = "white") +
    scale_color_manual(values = box_cols, guide = "none") +
    scale_alpha_manual(values = conf_alphas, guide = "none") +
    scale_size_manual(
      name   = "Confidence Level",
      values = conf_sizes
    )

  res <- res +
    scale_x_log10() +
    labs(x = data_column) +
    theme_info +
    theme(
      legend.position = "top",
      legend.title    = element_text(size = 10),
      legend.text     = element_text(size = 10),
      axis.title.y    = element_blank(),
      axis.title.x    = element_text(size = 10),
      axis.text       = element_text(size = 10)
    )

  res
}

# Run FindAllMarkers and keep markers with adjusted p-value below p_cutoff.
#
# input_sobj: Seurat object with identities set to the groups to compare.
# only_pos:   passed to FindAllMarkers(only.pos = ).
# p_cutoff:   threshold on p_val_adj.
# ...:        further FindAllMarkers arguments.
#
# Returns a tibble of markers.
find_markers <- function(input_sobj, only_pos = T, p_cutoff = 0.05, ...) {
  marker_df <- as_tibble(FindAllMarkers(input_sobj, only.pos = only_pos, ...))

  filter(marker_df, p_val_adj < p_cutoff)
}

# Find cluster markers within one cell type.
#
# Subsets input_sobj to cells whose grp_column equals input_grp, sets the
# identities to clust_column and runs find_markers().
#
# Returns the marker tibble with a cell_type column, or NULL when the subset
# has fewer than two clusters (nothing to compare).
find_group_markers <- function(input_grp, input_sobj, grp_column, clust_column) {

  grp_sobj <- subset(input_sobj, !!sym(grp_column) == input_grp)

  # FindAllMarkers needs at least two clusters
  if (n_distinct(grp_sobj@meta.data[, clust_column]) < 2) {
    return(NULL)
  }

  Idents(grp_sobj) <- FetchData(grp_sobj, clust_column)

  grp_markers <- find_markers(grp_sobj)

  mutate(grp_markers, cell_type = input_grp)
}

# Create reference UMAP for comparisons
#
# Outlined point plot from plot_features() with a legend on top and axes
# removed (blank_theme).
#
# input_sobj:  passed to plot_features().
# pt_mtplyr:   multiplier for the base point size of 0.1.
# color_guide: guide used for the color legend.
# ...:         further plot_features() arguments (e.g. feature, plot_cols).
create_ref_umap <- function(input_sobj, pt_mtplyr = 1, color_guide, ...) {
  legend_theme <- theme(
    legend.position = "top",
    legend.title    = element_blank(),
    legend.text     = element_text(size = 10)
  )

  feature_plot <- plot_features(
    input_sobj,
    pt_size    = 0.1 * pt_mtplyr,
    pt_outline = 0.4,
    ...
  )

  feature_plot + guides(color = color_guide) + blank_theme + legend_theme
}

# Create UMAPs showing marker gene signal
#
# Returns a list with one plot_features() UMAP per marker, colored from
# near-white to umap_col and titled with the marker name.
#
# input_sobj:    Seurat object.
# input_markers: character vector of features to plot.
# umap_col:      high end of the color gradient.
# add_outline:   optional outline point size (plot_features pt_outline).
# pt_mtplyr:     multiplier for the base point size of 0.25.
create_marker_umaps <- function(input_sobj, input_markers, umap_col, add_outline = NULL, pt_mtplyr = 1) {

  point_size <- 0.25 * pt_mtplyr

  # Compact bottom legend; white y-axis text keeps panels the same width as
  # plots that show a y axis
  marker_theme <- theme(
    plot.title        = element_text(size = 13),
    legend.position   = "bottom",
    legend.title      = element_blank(),
    legend.text       = element_text(size = 8),
    legend.key.height = unit(0.1, "cm"),
    legend.key.width  = unit(0.3, "cm"),
    axis.title.y      = element_text(size = 13, color = "white"),
    axis.text.y       = element_text(size = 8, color = "white")
  )

  lapply(input_markers, function(gene) {
    gene_plot <- plot_features(
      input_sobj,
      feature    = gene,
      plot_cols  = c("#fafafa", umap_col),
      pt_outline = add_outline,
      pt_size    = point_size
    )

    gene_plot + ggtitle(gene) + blank_theme + marker_theme
  })
}

# Create boxplots showing marker gene signal
#
# One facet per marker showing its signal per cluster. Boxplots are used when
# there are more than 6 clusters or all_boxes is TRUE, violins when
# all_violins is TRUE, and quasirandom points otherwise.
#
# input_sobj:     Seurat object.
# input_markers:  features to plot (facet labels truncated to 9 characters).
# clust_column:   cluster column; names(box_cols) give its level order.
# box_cols:       named cluster colors.
# group:          optional group to keep; the group is the cluster name
#                 without its leading "[a-zA-Z0-9_]+-" prefix.
# include_legend: show a top legend using the global col_guide.
# order_boxes:    order clusters by decreasing mean signal.
# n_boxes/n_rows: determine the number of facet columns.
# pt_mtplyr:      multiplier for the point size (0.3) of the point layout.
# ...:            passed to theme().
create_marker_boxes <- function(input_sobj, input_markers, clust_column, box_cols,
                                group = NULL, include_legend = F, all_boxes = F,
                                all_violins = F, order_boxes = T, n_boxes = 10,
                                n_rows = 2, pt_mtplyr = 1, ...) {

  clust_sym <- sym(clust_column)

  # Per-cell signal plus the group derived from the cluster name
  cell_df <- FetchData(input_sobj, c(clust_column, input_markers)) %>%
    as_tibble(rownames = "cell_id") %>%
    mutate(grp = str_remove(!!clust_sym, "^[a-zA-Z0-9_]+-"))

  # Facet labels are truncated to 9 characters, so truncate the marker names
  # the same way to use them as the facet order
  marker_levels <- str_trunc(input_markers, 9)

  if (!is.null(group)) {
    cell_df <- filter(cell_df, grp == group)
  }

  # One row per cell and marker
  long_df <- cell_df %>%
    pivot_longer(
      cols      = c(-cell_id, -grp, -!!clust_sym),
      names_to  = "key",
      values_to = "Counts"
    ) %>%
    mutate(
      !!clust_sym := fct_relevel(!!clust_sym, names(box_cols)),
      key = fct_relevel(str_trunc(key, width = 9, side = "right"), marker_levels)
    )

  if (order_boxes) {
    long_df <- mutate(long_df, !!clust_sym := fct_reorder(!!clust_sym, Counts, mean, .desc = T))
  }

  n_clust <- n_distinct(long_df[[clust_column]])
  n_cols  <- ceiling(n_boxes / n_rows)

  res <- ggplot(long_df, aes(!!clust_sym, Counts, color = !!clust_sym)) +
    facet_wrap(~ key, ncol = n_cols) +
    scale_color_manual(values = box_cols) +
    theme_info +
    theme(
      panel.spacing.x  = unit(0.7, "cm"),
      strip.background = element_blank(),
      strip.text       = element_text(size = 13),
      legend.position  = "none",
      axis.title.x     = element_blank(),
      axis.title.y     = element_text(size = 13),
      axis.text.x      = element_blank(),
      axis.text.y      = element_text(size = 8),
      axis.ticks.x     = element_blank(),
      axis.line.x      = element_blank()
    )

  # White square marking the median of each cluster
  median_pt <- stat_summary(fun = "median", geom = "point", shape = 22, size = 1, fill = "white")

  if (n_clust > 6 || all_boxes) {
    res <- res +
      geom_boxplot(aes(fill = !!clust_sym, color = !!clust_sym), size = 0, outlier.color = "#f0f0f0", outlier.size = 0.25) +
      median_pt +
      scale_fill_manual(values = box_cols)

  } else if (all_violins) {
    res <- res +
      geom_violin(aes(fill = !!clust_sym), size = 0.2) +
      median_pt +
      scale_fill_manual(values = box_cols) +
      scale_color_manual(values = box_cols)

  } else {
    res <- res +
      geom_quasirandom(size = 0.3 * pt_mtplyr)
  }

  res <- res + theme(...)

  if (include_legend) {
    res <- res +
      guides(color = col_guide) +
      theme(legend.position = "top")
  }

  # With few markers and several rows, place the plot in a grid so the
  # facets keep the size they would have in a full layout
  n_keys <- n_distinct(long_df$key)

  if (n_keys <= n_cols && n_rows > 1) {
    n_keys <- if (n_keys == 1) 2 else as.double(n_keys)
    n_cols <- floor(n_cols / n_keys)

    res <- plot_grid(res, ncol = n_cols, nrow = 2)
  }

  res
}

# Create figure summarizing marker genes
#
# Stacks three panels:
# 1. the reference UMAP (legend moved above) plus UMAPs of up to 7 top markers;
# 2. boxplots of the top n_boxes markers;
# 3. GO bubbles.
# Panels without data are left blank. When xlsx_name is set, the markers and
# the GO terms are appended to "<xlsx_name>_markers.xlsx" and
# "<xlsx_name>_GO.xlsx" in sheet `sheet_name`.
#
# input_markers: marker table with a `gene` column (rows in priority order).
# input_GO:      GO table for create_bubbles().
# input_umap:    reference UMAP ggplot.
# umap_color:    color for marker UMAPs and GO bubbles.
# fig_heights:   relative heights of the three panels.
# GO_genome:     not used in this function.
# box_colors:    cluster colors for create_marker_boxes().
# ...:           passed to create_marker_boxes().
create_marker_fig <- function(input_sobj, input_markers, input_GO, clust_column, 
                              input_umap, umap_color, fig_heights = c(0.46, 0.3, 0.3), 
                              GO_genome = params$genome, box_colors, n_boxes = 10,
                              umap_outline = NULL, umap_mtplyr = 1, xlsx_name = NULL, 
                              sheet_name = NULL, ...) {

  # Empty placeholder for panels without data
  placeholder <- ggplot() +
    geom_blank() +
    theme_void()

  marks_umap  <- placeholder
  marks_boxes <- placeholder
  GO_bubbles  <- placeholder

  if (nrow(input_markers) > 0) {
    top_marks <- head(input_markers$gene, n_boxes)

    # Move the reference UMAP legend above the UMAP grid
    clust_legend <- get_legend(input_umap)
    ref_umap     <- input_umap + theme(legend.position = "none")

    gene_umaps <- create_marker_umaps(
      input_sobj,
      input_markers = head(top_marks, 7),
      umap_col      = umap_color,
      add_outline   = umap_outline,
      pt_mtplyr     = umap_mtplyr
    )

    umap_grid <- plot_grid(
      plotlist = c(list(ref_umap), gene_umaps),
      ncol     = 4,
      nrow     = 2,
      align    = "vh",
      axis     = "trbl"
    )

    marks_umap <- plot_grid(
      clust_legend, umap_grid,
      rel_heights = c(0.2, 0.9),
      nrow        = 2
    )

    marks_boxes <- create_marker_boxes(
      input_sobj,
      input_markers = top_marks,
      clust_column  = clust_column,
      box_cols      = box_colors,
      n_boxes       = n_boxes,
      plot.margin   = unit(c(0.8, 0.2, 0.2, 0.2), "cm"),
      ...
    )

    if (nrow(input_GO) > 0) {
      GO_bubbles <- create_bubbles(input_GO, plot_colors = umap_color) +
        theme(
          plot.margin      = unit(c(0.8, 0.2, 0.2, 0.2), "cm"),
          strip.background = element_blank(),
          strip.text       = element_text(size = 13),
          axis.title.y     = element_text(size = 13),
          axis.text.y      = element_text(size = 8),
          axis.line.x      = element_blank(),
          legend.position  = "bottom",
          legend.title     = element_blank(),
          legend.text      = element_text(size = 8)
        )

      if (!is.null(xlsx_name)) {
        GO_out <- input_GO %>%
          dplyr::select(
            term_name,  term_id,
            source,     effective_domain_size,
            query_size, intersection_size,
            p_value,    significant 
          ) %>%
          arrange(source, p_value)

        write.xlsx(
          GO_out,
          file      = str_c(xlsx_name, "_GO.xlsx"),
          sheetName = sheet_name,
          append    = T
        )
      }
    }

    if (!is.null(xlsx_name)) {
      write.xlsx(
        input_markers,
        file      = str_c(xlsx_name, "_markers.xlsx"),
        sheetName = sheet_name,
        append    = T
      )
    }
  }

  fig_panels <- list(marks_umap, marks_boxes, GO_bubbles)

  # NOTE(review): panels are not aligned when there are fewer markers than
  # n_boxes; presumably because the boxplot panel may then be a plot_grid
  # wrapper that cannot be aligned — confirm
  if (nrow(input_markers) < n_boxes) {
    return(plot_grid(plotlist = fig_panels, rel_heights = fig_heights, ncol = 1))
  }

  plot_grid(
    plotlist    = fig_panels,
    rel_heights = fig_heights,
    ncol        = 1,
    align       = "v",
    axis        = "rl"
  )
}

# Filter clusters and set cluster order
#
# input_cols:  named colors; their names give the cluster order.
# input_marks: marker table with a `cluster` column.
# n_cutoff:    minimum number of marker rows a cluster needs to be kept.
#
# Returns names(input_cols) restricted to clusters with at least n_cutoff
# marker rows, in the order of input_cols.
set_cluster_order <- function(input_cols, input_marks, n_cutoff = 5) {
  # Marker rows per cluster present in the data (NA kept as its own group)
  marks_per_clust <- table(as.character(input_marks$cluster), useNA = "ifany")
  kept_clusters   <- names(marks_per_clust)[marks_per_clust >= n_cutoff]

  col_order <- names(input_cols)
  col_order[col_order %in% kept_clusters]
}
  
# Create v1 panel for marker genes
#
# Finds markers for every cluster in clust_column, runs g:Profiler on each
# cluster's markers (genome from the global `params$genome`), and prints a
# create_marker_fig() panel under a "####" heading for every cluster in
# set_cluster_order(input_cols, markers).
#
# input_cols:  named cluster colors (names give the panel order).
# input_umap:  optional reference UMAP; when NULL one is built per cluster
#              with that cluster drawn last.
# uniq_GO:     keep only GO terms found for a single cluster.
# umap_mtplyr: point-size multiplier, used only for objects < 500 cells.
# xlsx_name:   optional prefix for Excel output (see create_marker_fig).
# ...:         passed to create_marker_fig().
#
# Called for its printed output.
create_marker_panel_v1 <- function(input_sobj, input_cols, input_umap = NULL, clust_column, order_boxes = T,
                                   color_guide = guide_legend(override.aes = list(size = 3.5, shape = 16)),
                                   uniq_GO = F, umap_mtplyr = 6, xlsx_name = NULL, ...) {

  # Set point size. Plain if/else for scalars: if_else() is type-strict and
  # errors in dplyr < 1.1 when e.g. an integer umap_mtplyr meets the double 1
  umap_mtplyr <- if (ncol(input_sobj) < 500) umap_mtplyr else 1
  ref_mtplyr  <- if (umap_mtplyr == 1) umap_mtplyr else umap_mtplyr * 2.5

  # Find marker genes
  Idents(input_sobj) <- input_sobj %>%
    FetchData(clust_column)

  markers <- find_markers(input_sobj)

  # Find GO terms per cluster, using genes ordered by adjusted p-value
  GO_df <- markers %>%
    group_by(cluster) %>%
    do({
      arrange(., p_val_adj) %>%
        pull(gene) %>%
        run_gprofiler(
          genome = params$genome,
          ordered_query = T
        )
    }) %>%
    ungroup()

  if (uniq_GO && nrow(GO_df) > 0) {
    GO_df <- GO_df %>%
      group_by(term_id) %>%
      filter(n() == 1) %>%
      ungroup()
  }

  # Set cluster order based on order of input_cols
  fig_clusters <- input_cols %>%
    set_cluster_order(markers)

  # Create figures
  for (i in seq_along(fig_clusters)) {
    cat("\n#### ", fig_clusters[i], "\n", sep = "")

    # Filter markers and GO terms
    clust <- fig_clusters[i]

    fig_marks <- markers %>%
      filter(cluster == clust)

    fig_GO <- GO_df %>%
      filter(cluster == clust)

    # Create reference umap
    ref_umap <- input_umap
    umap_col <- input_cols[clust]

    if (is.null(input_umap)) {
      # Put the current cluster last so it is drawn on top
      umap_levels <- input_cols[names(input_cols) != clust]
      umap_levels <- names(c(umap_levels, umap_col))

      ref_umap <- input_sobj %>%
        create_ref_umap(
          feature     = clust_column,
          plot_cols   = input_cols,
          feat_levels = umap_levels,
          pt_mtplyr   = ref_mtplyr,
          color_guide = color_guide
        )
    }

    # Create panel
    marker_fig <- input_sobj %>%
      create_marker_fig(
        input_markers = fig_marks,
        input_GO      = fig_GO,
        clust_column  = clust_column,
        input_umap    = ref_umap,
        umap_color    = umap_col,
        box_colors    = input_cols,
        order_boxes   = order_boxes,
        umap_mtplyr   = umap_mtplyr,
        xlsx_name     = xlsx_name,
        sheet_name    = clust,
        ...
      )

    cat(nrow(fig_marks), "marker genes were identified,", nrow(fig_GO), "GO terms were identified.")
    print(marker_fig)
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  }
}

# Create v2 panel that splits plots into groups
#
# Like create_marker_panel_v1, but uses precomputed markers and colors each
# reference UMAP only by the clusters of the current cluster's group (the
# cluster name without its leading "[a-zA-Z0-9_]+-" prefix); cells of other
# groups are shown as "Other". GO terms use the global `params$genome`.
#
# input_markers: marker table with cluster, gene and p_val_adj columns.
# input_cols:    named cluster colors (names give the panel order).
# grp_column:    column holding each cell's group.
# clust_column:  column holding each cell's cluster.
# uniq_GO:       keep only GO terms found for a single cluster.
# umap_mtplyr:   point-size multiplier, used only for objects < 500 cells.
# ...:           passed to create_marker_fig().
#
# Called for its printed output.
create_marker_panel_v2 <- function(input_sobj, input_markers, input_cols, grp_column, clust_column, 
                                   color_guide = guide_legend(override.aes = list(size = 3.5, shape = 16)), 
                                   uniq_GO = F, umap_mtplyr = 6, xlsx_name = NULL, ...) {

  # Set point size. Plain if/else for scalars: if_else() is type-strict and
  # errors in dplyr < 1.1 when e.g. an integer umap_mtplyr meets the double 1
  small_obj   <- ncol(input_sobj) < 500
  umap_mtplyr <- if (small_obj) umap_mtplyr else 1
  ref_mtplyr  <- if (small_obj) umap_mtplyr * 2.5 else 1

  # Figure colors and order
  fig_clusters <- input_cols %>%
    set_cluster_order(input_markers)

  # Find GO terms per cluster, using genes ordered by adjusted p-value
  GO_df <- input_markers %>%
    group_by(cluster) %>%
    do({
      arrange(., p_val_adj) %>%
        pull(gene) %>%
        run_gprofiler(
          genome = params$genome,
          ordered_query = T
        )
    }) %>%
    ungroup()

  if (uniq_GO && nrow(GO_df) > 0) {
    GO_df <- GO_df %>%
      group_by(term_id) %>%
      filter(n() == 1) %>%
      ungroup()
  }

  # Create figures
  for (i in seq_along(fig_clusters)) {
    cat("\n#### ", fig_clusters[i], "\n", sep = "")

    # Filter markers and GO terms
    clust <- fig_clusters[i]

    fig_marks <- input_markers %>%
      filter(cluster == clust)

    fig_GO <- GO_df %>%
      filter(cluster == clust)

    # Set colors
    umap_col <- input_cols[clust]

    group <- clust %>%
      str_remove("^[a-zA-Z0-9_]+-")

    # Colors of all clusters in this group (names ending in "-<group>").
    # A literal suffix match treats "+", "." etc. in group names as-is.
    fig_cols <- input_cols[endsWith(names(input_cols), str_c("-", group))]
    fig_cols <- c("Other" = "#fafafa", fig_cols)

    # Create reference UMAP. `.env$group` avoids picking up a data column
    # that happens to be named "group"; as.character() lets a factor cluster
    # column be combined with "Other".
    ref_umap <- input_sobj %>%
      FetchData(c("UMAP_1", "UMAP_2", grp_column, clust_column)) %>%
      as_tibble(rownames = "cell_id") %>%
      mutate(!!sym(clust_column) := if_else(
        !!sym(grp_column) != .env$group,
        "Other",
        as.character(!!sym(clust_column))
      )) %>%
      create_ref_umap(
        feature     = clust_column,
        plot_cols   = fig_cols,
        feat_levels = names(fig_cols),
        pt_mtplyr   = ref_mtplyr,
        color_guide = color_guide
      )

    # Create panel
    marker_fig <- input_sobj %>%
      create_marker_fig(
        input_markers = fig_marks,
        input_GO      = fig_GO,
        clust_column  = clust_column,
        input_umap    = ref_umap,
        umap_color    = umap_col,
        box_colors    = fig_cols,
        group         = group,
        umap_mtplyr   = umap_mtplyr,
        xlsx_name     = xlsx_name,
        sheet_name    = clust,
        ...
      )

    cat(nrow(fig_marks), "marker genes were identified.", nrow(fig_GO), "GO terms were identified.")
    print(marker_fig)
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  }
}

# Create panels for manuscript
#
# Same workflow as create_marker_panel_v1: markers and GO terms (genome from
# the global `params$genome`) per cluster, one printed create_marker_fig()
# panel per cluster. If summary_fig is given, it is stacked above every panel
# (30% / 70% of the height). No marker/GO count line is printed.
#
# input_cols:  named cluster colors (names give the panel order).
# summary_fig: optional plot shown above each panel.
# input_umap:  optional reference UMAP; when NULL one is built per cluster.
# uniq_GO:     keep only GO terms found for a single cluster.
# umap_mtplyr: point-size multiplier, used only for objects < 500 cells.
# ...:         passed to create_marker_fig().
#
# Called for its printed output.
create_paper_figures <- function(input_sobj, input_cols, summary_fig = NULL, input_umap = NULL, clust_column,
                                 color_guide = guide_legend(override.aes = list(size = 3.5, shape = 16)),
                                 order_boxes = T, uniq_GO = F, umap_mtplyr = 6, xlsx_name = NULL, ...) {

  # Set point size. Plain if/else for scalars: if_else() is type-strict and
  # errors in dplyr < 1.1 when e.g. an integer umap_mtplyr meets the double 1
  umap_mtplyr <- if (ncol(input_sobj) < 500) umap_mtplyr else 1
  ref_mtplyr  <- if (umap_mtplyr == 1) umap_mtplyr else umap_mtplyr * 2.5

  # Find marker genes
  Idents(input_sobj) <- input_sobj %>%
    FetchData(clust_column)

  markers <- find_markers(input_sobj)

  # Find GO terms per cluster, using genes ordered by adjusted p-value
  GO_df <- markers %>%
    group_by(cluster) %>%
    do({
      arrange(., p_val_adj) %>%
        pull(gene) %>%
        run_gprofiler(
          genome = params$genome,
          ordered_query = T
        )
    }) %>%
    ungroup()

  if (uniq_GO && nrow(GO_df) > 0) {
    GO_df <- GO_df %>%
      group_by(term_id) %>%
      filter(n() == 1) %>%
      ungroup()
  }

  # Set cluster order based on order of input_cols
  fig_clusters <- input_cols %>%
    set_cluster_order(markers)

  # Create figures
  for (i in seq_along(fig_clusters)) {
    cat("\n#### ", fig_clusters[i], "\n", sep = "")

    # Filter markers and GO terms
    clust <- fig_clusters[i]

    fig_marks <- markers %>%
      filter(cluster == clust)

    fig_GO <- GO_df %>%
      filter(cluster == clust)

    # Create reference umap
    ref_umap <- input_umap
    umap_col <- input_cols[clust]

    if (is.null(input_umap)) {
      # Put the current cluster last so it is drawn on top
      umap_levels <- input_cols[names(input_cols) != clust]
      umap_levels <- names(c(umap_levels, umap_col))

      ref_umap <- input_sobj %>%
        create_ref_umap(
          feature     = clust_column,
          plot_cols   = input_cols,
          feat_levels = umap_levels,
          pt_mtplyr   = ref_mtplyr,
          color_guide = color_guide
        )
    }

    # Create panel
    marker_fig <- input_sobj %>%
      create_marker_fig(
        input_markers = fig_marks,
        input_GO      = fig_GO,
        clust_column  = clust_column,
        input_umap    = ref_umap,
        umap_color    = umap_col,
        box_colors    = input_cols,
        order_boxes   = order_boxes,
        umap_mtplyr   = umap_mtplyr,
        xlsx_name     = xlsx_name,
        sheet_name    = clust,
        ...
      )

    if (!is.null(summary_fig)) {
      marker_fig <- plot_grid(
        summary_fig, marker_fig,
        rel_heights = c(0.3, 0.7),
        ncol  = 1,
        align = "vh",
        axis  = "trbl"
      )
    }

    print(marker_fig)
    cat("\n\n---\n\n<br>\n\n<br>\n\n")
  }
}


# Default chunk options
knitr::opts_chunk$set(message = FALSE, warning = FALSE)

# Load packages
R_packages <- c(
  "tidyverse",  "Seurat",
  "gprofiler2", "knitr",
  "cowplot",    "ggbeeswarm",
  "ggrepel",    "RColorBrewer",
  "xlsx",       "colorblindr",
  "ggforce",    "broom",
  "mixtools",   "clustifyr",
  "boot",       "scales"
)

# Attach every package in order. lapply() avoids leaving a loop variable
# (`package`) in the global environment; invisible() suppresses printing of
# the returned list in the knitted output.
invisible(lapply(R_packages, library, character.only = TRUE))


# ggplot2 themes
# theme_info: cowplot base with plain-weight titles and light-grey strips
theme_info <- theme_cowplot() +
  theme(
    strip.text       = element_text(face = "plain"),
    strip.background = element_rect(fill = "#fafafa"),
    plot.title       = element_text(face = "plain", size = 20)
  )

# umap_theme: theme_info without axis tick labels or tick marks
umap_theme <- theme_info +
  theme(axis.ticks = element_blank(), axis.text = element_blank())

# blank_theme: umap_theme additionally without axis lines and axis titles
blank_theme <- umap_theme +
  theme(axis.title = element_blank(), axis.line = element_blank())

# Legend guides
# col_guide: enlarged solid points (shape 16) in the legend
col_guide <- guide_legend(
  override.aes = list(size = 3.5, shape = 16)
)

# outline_guide: enlarged fillable points (shape 21) with a thin black border
outline_guide <- guide_legend(
  override.aes = list(size = 3.5, shape = 21, color = "black", stroke = 0.25)
)

# Base color palettes
base_cols <- c(
  "#225ea8",  # blue
  "#e31a1c",  # red
  "#238443",  # green
  "#ec7014",  # orange
  "#6a51a3",  # purple
  "#c51b7d",  # pink
  "#8c510a",  # brown
  "#217D87",  # teal, darken("#41b6c4", 0.3)
  "#F0E442",  # yellow, palette_OkabeIto[4]
  "#000000"   # black
)

# One pair per base color, c(lightened + desaturated tint, original color),
# with the list named by the original color
base_cols_paired <- setNames(
  map(base_cols, function(clr) c(desaturate(lighten(clr, 0.25), 0.2), clr)),
  base_cols
)

# Extended palette: the base colors followed by their tints
base_cols <- c(base_cols, desaturate(lighten(base_cols, 0.25), 0.2))

# Okabe Ito color palettes
# Okabe-Ito colors with a red and a purple inserted
ito_cols <- c(
  palette_OkabeIto[1:4], "#d7301f",
  palette_OkabeIto[5:6], "#6a51a3",
  palette_OkabeIto[7:8]
)

# One pair per color, c(original color, 30% darker), named by the original
ito_cols_paired <- setNames(
  map(ito_cols, function(clr) c(clr, darken(clr, 0.3))),
  ito_cols
)

# Extended palette: originals, then 40% darker versions, then black
ito_cols <- c(ito_cols, darken(ito_cols, 0.4), "#000000")

# Set color palette
# Active palette: Okabe-Ito. The base palette assignments that used to precede
# these lines were immediately overwritten (dead stores); to switch palettes use
#   theme_cols <- base_cols; paired_cols <- base_cols_paired
theme_cols  <- ito_cols
paired_cols <- ito_cols_paired
# Function to add one to variable
# Increment a variable, given by name, by `n` and return the new value.
# The assignment follows `<<-` scoping: the first environment enclosing this
# function in which `x` is bound (the global environment if none is).
#
# x  String naming the variable; parsed as R code, so e.g. "counts$a" works.
# n  Amount to add.
plus_one <- function(x, n = 1) {
  target <- str2lang(x)
  
  # Evaluate in an empty child of the enclosing environment: `<<-` searches
  # exactly where it would from this function, but the right-hand side can no
  # longer be shadowed by this function's own locals (e.g. x = "n"). Inlining
  # `n` as a value instead of pasting it into text keeps full precision.
  eval_env  <- new.env(parent = parent.env(environment()))
  new_value <- eval(bquote(.(target) <<- .(target) + .(n)), envir = eval_env)
  
  new_value
}

# Function to set DC colors
# Assign colors to cell types while keeping the fixed colors in `other_cols`.
#
# types_in    Cell type labels (e.g. unique values of a meta.data column).
# cols_in     Candidate colors, used in order.
# other_cols  Named vector of fixed colors (names are cell types). These types
#             and colors are removed from types_in / cols_in before pairing.
#
# Returns a named color vector: the remaining types paired in order with the
# remaining colors (entries whose type is NA are dropped), followed by
# other_cols. Errors if more types than colors remain.
get_DC_cols <- function(types_in, cols_in, other_cols) {
  
  types_in <- types_in[!types_in %in% names(other_cols)]
  cols_in  <- cols_in[!cols_in %in% other_cols]
  
  # Explicit message instead of R's "'names' attribute must be the same
  # length as the vector" error from names<- below
  if (length(types_in) > length(cols_in)) {
    stop(
      "Not enough colors: ", length(types_in), " cell types but only ",
      length(cols_in), " colors remain after removing other_cols.",
      call. = FALSE
    )
  }
  
  # Surplus colors receive NA names and are dropped together with NA types
  names(cols_in) <- types_in
  cols_in <- cols_in[!is.na(names(cols_in))]
  
  c(cols_in, other_cols)
}

# Function to subset Seurat objects for plotting
# Read a Seurat object, keep cells from the CD45 fraction that `type` is taken
# from, restrict to `type` plus `include_types` and re-cluster it.
#
# sobj_path     File name relative to params$rds_dir (read with read_rds()).
# type          One of "DC", "LEC", "fibroblast".
# type_column   meta.data column holding cell type labels.
# include_types Additional cell types kept alongside `type`.
# ...           Passed to cluster_RNA().
#
# Note: orig.ident is overwritten with "CD45_neg" / "CD45_pos".
subset_sobj <- function(sobj_path, type, type_column = "cell_type1",
                        include_types = c("B cell", "T cell", "epithelial"), ...) {
  
  sobj <- read_rds(file.path(params$rds_dir, sobj_path))
  
  # Barcodes ending in "_1" are labelled CD45-negative, all others CD45-positive
  cell_meta <- rownames_to_column(sobj@meta.data, "cell_id")
  cell_meta <- mutate(
    cell_meta,
    orig.ident = if_else(grepl("_1$", cell_id), "CD45_neg", "CD45_pos")
  )
  sobj@meta.data <- column_to_rownames(cell_meta, "cell_id")
  
  # CD45 fraction each supported cell type is taken from
  status_by_type <- c(
    DC         = "CD45_pos",
    LEC        = "CD45_neg",
    fibroblast = "CD45_neg"
  )
  keep_status <- status_by_type[type]
  
  if (is.na(keep_status)) {
    stop(str_c("ERROR: CD45 status not found for ", type))
  }
  
  sobj <- subset(sobj, subset = orig.ident == keep_status)
  
  if (!type %in% pull(sobj@meta.data, type_column)) {
    stop(str_c("ERROR: Cell type not present after filtering for ", keep_status))
  }
  
  # Keep the target type plus reference types, then re-run clustering / UMAP
  sobj <- subset(sobj, subset = !!sym(type_column) %in% c(type, include_types))
  
  cluster_RNA(sobj, ...)
}

# Function to add arrows to axis
# Draw two short arrowed axis lines (x and y) starting at the lower-left corner
# of a ggplot, e.g. in place of full axes on a UMAP.
#
# gg_in  ggplot with continuous x and y scales (ranges are read from the first
#        panel of the built plot).
# fract  Arrow length past the minimum, as a fraction of each axis' data range.
# ...    Extra parameters passed to both geom_segment() layers.
#
# Returns gg_in with both segments added and left-aligned, size-10 axis titles.
add_arrow_axis <- function(gg_in, fract = 0.1, ...) {
  
  # Start 5% of the data range below the minimum (presumably to sit on the
  # panel edge under ggplot2's default 5% scale expansion) and end `fract` of
  # the range above the minimum.
  get_line_coords <- function(range_in, fract) {
    span <- range_in[2] - range_in[1]
    
    c(range_in[1] - (span * 0.05), range_in[1] + (span * fract))
  }
  
  # Build the plot once and read both trained axis ranges from it
  plot_layout <- ggplot_build(gg_in)$layout
  
  x_coords <- get_line_coords(plot_layout$panel_scales_x[[1]]$range$range, fract)
  y_coords <- get_line_coords(plot_layout$panel_scales_y[[1]]$range$range, fract)
  
  axis_arrow <- arrow(ends = "last", type = "open", length = unit(0.02, "npc"))
  
  # One-row layer data with inherit.aes = FALSE: with the plot's data
  # inherited, ggplot2 draws one identical segment per data row (one per cell
  # on a UMAP). Positions stay fixed parameters, so they do not retrain scales.
  seg_data <- data.frame(n = 1)
  
  res <- gg_in +
    geom_segment(
      data        = seg_data,
      inherit.aes = FALSE,
      x           = x_coords[1],
      xend        = x_coords[2],
      y           = y_coords[1],
      yend        = y_coords[1],
      size        = 0.25,
      color       = "black",
      linejoin    = "bevel",
      arrow       = axis_arrow,
      ...
    ) +
    geom_segment(
      data        = seg_data,
      inherit.aes = FALSE,
      y           = y_coords[1],
      yend        = y_coords[2],
      x           = x_coords[1],
      xend        = x_coords[1],
      size        = 0.25,
      color       = "black",
      linejoin    = "bevel",
      arrow       = axis_arrow,
      ...
    ) +
    theme(axis.title = element_text(hjust = 0, size = 10))
  
  res
}

# Function to create figure panel
# Build the plots for one Seurat object: a UMAP colored by cell type, violin +
# box plots of `data_column` per cell type (log10 axis, types ordered by
# median) and, optionally, a UMAP colored by `data_column`.
#
# sobj_in      Seurat object containing `type_column` and `data_column`.
# cols_in      Cell type colors. If unnamed, names are assigned with
#              get_DC_cols(), reserving `other_cols` for those types.
# type_column  meta.data column with cell type labels.
# data_column  Length-1 character naming the value column; its name, if any,
#              becomes the box-plot axis label.
# title        Plot / legend title; also the list name given to the UMAP when
#              `gg_out` is used.
# gg_out       Name (string) of a list variable to append the plots to instead
#              of returning them. It is looked up and assigned with `<<-`
#              scoping rules and treated as NULL if it does not exist yet.
# legd_rows    Number of legend rows on the cell type UMAP.
# arrow_axis   Add arrow axes with add_arrow_axis()?
# pt_size, pt_outline  Point size / outline for plot_features().
# cell_counts  Label cell types as "<type>\n(n = <cells>)" on the UMAP?
# data_umap    Also build a UMAP colored by `data_column`?
# other_cols   Fixed colors used only when `cols_in` is unnamed.
# ...          Passed to theme() for the cell type UMAP.
#
# Returns list(umap, boxes[, data UMAP]), or NULL when `gg_out` is given.
create_fig <- function(sobj_in, cols_in, type_column = "cell_type2",
                       data_column = c("Fold-change relative to T/B cell abundance (log10)" = "ova_fc"), 
                       title = NULL, gg_out = NULL, legd_rows = 10, arrow_axis = FALSE, 
                       pt_size = 0.1, pt_outline = 0.4, cell_counts = TRUE, data_umap = FALSE,
                       other_cols = c("B cell" = "#E69F00", "T cell" = "#009E73"),
                       ...) {
  
  # Environment a `<<-` assignment to `name` would target when searching from
  # `env`: the first environment binding `name`, otherwise the global env.
  find_out_env <- function(name, env) {
    while (!identical(env, emptyenv())) {
      if (exists(name, envir = env, inherits = FALSE)) {
        return(env)
      }
      env <- parent.env(env)
    }
    
    globalenv()
  }
  
  # Set subtype colors
  if (is.null(names(cols_in))) {
    
    cell_types <- sobj_in@meta.data %>%
      pull(type_column) %>%
      unique()
    
    cols_in <- get_DC_cols(
      types_in   = cell_types,
      cols_in    = cols_in, 
      other_cols = other_cols
    )
  }
  
  # Order cell types by the median of data_column
  box_data <- sobj_in@meta.data %>%
    as_tibble(rownames = "cell_id") %>%
    mutate(
      cell_type = !!sym(type_column),
      cell_type = fct_reorder(cell_type, !!sym(data_column), median)
    )
  
  type_order <- box_data %>%
    pull(cell_type) %>%
    sort() %>%
    unique()
  
  # UMAP inputs; replaced below when labels include cell counts
  umap_data   <- sobj_in
  umap_cols   <- cols_in
  umap_column <- type_column
  umap_order  <- type_order
  
  if (cell_counts) {
    umap_data <- umap_data %>%
      FetchData(c(type_column, data_column, "UMAP_1", "UMAP_2")) %>%
      as_tibble(rownames = "cell_id") %>%
      mutate(cell_type = !!sym(type_column)) %>%
      group_by(cell_type) %>%
      mutate(cell_count = n_distinct(cell_id)) %>%
      ungroup() %>%
      mutate(count_lab = str_c(cell_type, "\n(n = ", cell_count, ")"))
    
    # Look colors up by label: indexing a named vector with a factor would use
    # the factor's integer codes instead of its labels
    label_cols <- umap_data %>%
      select(cell_type, count_lab) %>%
      unique() %>%
      mutate(color = cols_in[as.character(cell_type)])
    
    label_cols$cell_type <- factor(label_cols$cell_type, umap_order)
    
    # Legend order: reverse of the box plot order
    umap_order <- label_cols %>%
      arrange(cell_type) %>%
      pull(count_lab) %>%
      rev()
    
    umap_cols   <- setNames(label_cols$color, label_cols$count_lab)
    umap_column <- "count_lab"
  }
  
  # Subtype UMAP
  umap <- umap_data %>%
    plot_features(
      feature     = umap_column,
      pt_size     = pt_size,
      pt_outline  = pt_outline,
      plot_cols   = umap_cols,
      feat_levels = umap_order
    ) +
    guides(color = guide_legend(override.aes = list(size = 3.5), nrow = legd_rows, title = title)) +
    ggtitle(title) +
    blank_theme +
    theme(legend.text = element_text(size = 10)) +
    theme(...)
  
  if (arrow_axis) {
    umap <- add_arrow_axis(umap)
  }
  
  # OVA UMAP
  if (data_umap) {
    ova_umap <- umap_data %>%
      plot_features(
        feature    = data_column,
        pt_size    = 0.3,
        pt_outline = 0.5
      ) +
      blank_theme +
      theme(
        legend.title      = element_blank(),
        legend.text       = element_text(size = 10),
        plot.margin       = unit(c(0.8, 0.5, 0, 2.1), "cm"),
        legend.position   = "bottom",
        legend.key.height = unit(0.3, "cm")
        # axis.line        = element_line(size = 0.5),
        # legend.key.width = unit(0.3, "cm")
      )
    
    if (arrow_axis) {
      ova_umap <- add_arrow_axis(ova_umap)
    }
  }
  
  # OVA boxes
  boxes <- box_data %>%
    ggplot(aes(cell_type, !!sym(data_column), color = cell_type, fill = cell_type)) +
    geom_violin(size = 0.5) +
    geom_boxplot(
      fill          = NA, 
      size          = 0.5, 
      outlier.size  = 0.5, 
      outlier.color = NA, 
      ymin          = NA, 
      ymax          = NA, 
      fatten        = 1, 
      color         = "black", 
      width         = 0.35
    ) +
    scale_color_manual(values = cols_in) +
    scale_fill_manual(values = cols_in) +
    coord_flip() + 
    scale_y_log10(labels = trans_format("log10", math_format(10^.x))) +
    theme_minimal_vgrid() +
    theme(
      legend.position = "none",
      axis.title.y    = element_blank(),
      axis.title      = element_text(size = 10),
      axis.text       = element_text(size = 10),
      panel.grid.major.x = element_line(size = 0.15)
    )
  
  if (!is.null(names(data_column))) {
    boxes <- boxes +
      labs(y = names(data_column))
  }
  
  # Append to the list named by gg_out: the UMAP named by `title`, then the
  # unnamed box plot (and data UMAP). Done with assign() rather than
  # eval(parse()), which failed for titles containing quotes and when gg_out
  # matched a local variable name.
  if (!is.null(gg_out)) {
    out_env   <- find_out_env(gg_out, parent.env(environment()))
    new_plots <- c(setNames(list(umap), title), list(boxes))
    
    if (data_umap) {
      new_plots <- c(new_plots, list(ova_umap))
    }
    
    old_plots <- get0(gg_out, envir = out_env, inherits = FALSE)
    assign(gg_out, append(old_plots, new_plots), envir = out_env)
    
    return(NULL)
  }
  
  res <- list(umap, boxes)
  
  if (data_umap) {
    res <- append(res, list(ova_umap))
  }
  
  res
}

# Function to create final figure
# Arrange a list of plots into a two-column figure: named entries (UMAPs) are
# stacked in the left column with letter labels, unnamed ("") entries (box
# plots and any data UMAPs) are stacked in the right column.
#
# ggs_in     List of ggplots, e.g. as collected by create_fig(gg_out = ...).
# rel_width  Relative widths of the left and right columns.
create_final_fig <- function(ggs_in, rel_width = c(1, 0.7)) {
  
  # A list without names is treated as all unnamed entries
  gg_names <- names(ggs_in)
  
  if (is.null(gg_names)) {
    gg_names <- rep("", length(ggs_in))
  }
  
  umap_list <- ggs_in[gg_names != ""]
  box_list  <- ggs_in[gg_names == ""]
  
  # Hide legends on all right-column plots except the first. `[-1]` selects
  # nothing for 0 or 1 plots; `2:length(box_list)` became 2:1 for a single
  # plot and added a NULL panel.
  box_list[-1] <- map(box_list[-1], ~ .x + theme(legend.position = "none"))
  
  umaps <- plot_grid(
    plotlist       = umap_list,
    ncol           = 1,
    align          = "v",
    labels         = letters[seq_along(umap_list)],
    label_y        = 1,
    label_size     = 22,
    label_fontface = "plain"
  )
  
  boxes <- plot_grid(
    plotlist = box_list,
    ncol     = 1,
    align    = "vh",
    axis     = "trbl"
  )
  
  res <- plot_grid(
    umaps, boxes,
    rel_widths = rel_width,
    ncol = 2
  )
  
  res
}

# Create list of Seurat objects
# Each params$sobjs entry holds: file name (relative to params$rds_dir),
# cell type passed to subset_sobj(), and the display name for the object
so_paths <- map_chr(params$sobjs, ~ .x[1])
so_types <- map_chr(params$sobjs, ~ .x[2])
so_names <- map_chr(params$sobjs, ~ .x[3])

# Load and subset every object, labelling the list with the display names
sobjs <- setNames(map2(so_paths, so_types, subset_sobj), so_names)

# Split the Seurat objects into DC objects and all other objects
DC_sobjs    <- sobjs[so_types == "DC"]
other_sobjs <- sobjs[so_types != "DC"]

# Distinct cell_type2 labels found across all DC objects
DC_types <- unique(reduce(map(DC_sobjs, ~ unique(.x$cell_type2)), c))

# Fixed B/T cell colors used with the base palette
T_B_cols <- c(
  "B cell" = "#E69F00",
  "T cell" = "#676767"
)

# DC color palettes: DC types get palette colors, B/T cells keep fixed colors
ito_cols_2 <- get_DC_cols(
  types_in   = DC_types,
  cols_in    = ito_cols,
  other_cols = c("B cell" = "#E69F00", "T cell" = "#009E73")
)

base_cols_2 <- get_DC_cols(
  types_in   = DC_types,
  cols_in    = base_cols,
  other_cols = T_B_cols
)

# Plot counter, presumably incremented later with plus_one("n_plot")
n_plot <- 0